import itertools

import numpy as np
import scipy
import torch
from IPython import embed

import ipdb
import pickle
import os

import scipy as sp

# Default torch device for every tensor this module creates: the GPU when
# CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


class Controller:
    """Common base for the bandit policies in this module.

    Stores the current context ``batch``, the environment ``env`` and an
    optional display ``name``. The subclasses below implement ``act`` and/or
    ``act_numpy_vec`` on top of this state.
    """

    def __init__(self):
        self.name = None

    def set_batch(self, batch):
        """Remember ``batch`` for later use by the subclass's act methods."""
        self.batch = batch

    def set_batch_numpy_vec(self, batch):
        """Vectorised entry point; by default it simply forwards to ``set_batch``."""
        # Dispatch through self so subclasses that override set_batch are honoured.
        self.set_batch(batch)

    def set_env(self, env):
        """Attach the environment the policy acts in."""
        self.env = env

    def set_name(self, name):
        """Set the human-readable label of this controller."""
        self.name = name

    def get_name(self):
        """Return the label set by ``set_name`` (``None`` by default)."""
        return self.name


class OptPolicy(Controller):
    """Oracle policy that always plays the environment's optimal action.

    Args:
        env: a single environment for ``act`` or a sequence of environments
            for ``act_numpy_vec``; each must expose ``opt_a``.
        batch_size: stored for interface parity with the other policies.
        opt_a_log_path: optional file path. When given, every call to
            ``act_numpy_vec`` appends the stacked optimal actions to that file
            with ``pickle.dump``, for offline analysis. When ``None`` (the
            default), nothing is written.
    """

    def __init__(self, env, batch_size=1, opt_a_log_path=None):
        super().__init__()
        self.env = env
        self.batch_size = batch_size
        self.opt_a_log_path = opt_a_log_path

    def reset(self):
        return

    def act(self, x):
        """Return ``env.opt_a`` of the single environment; ``x`` is ignored."""
        return self.env.opt_a

    def act_numpy_vec(self, x):
        """Stack ``opt_a`` of every environment in ``self.env`` along a new axis 0."""
        opt_as = np.stack([env.opt_a for env in self.env], axis=0)

        # Opt-in replacement for the hand-toggled dump blocks used for analysis.
        if self.opt_a_log_path is not None:
            with open(self.opt_a_log_path, "ab") as f:
                pickle.dump(opt_as, f)

        return opt_as


class GreedyOptPolicy(Controller):
    """Replay the context action that earned the highest observed reward."""

    def __init__(self, env):
        super().__init__()
        self.env = env

    def reset(self):
        """Stateless policy: nothing to clear."""
        return

    def act(self, x):
        """Return the action from the step with the largest context reward.

        Rewards are flattened and actions are taken from batch element 0, so
        this is meant for a single-context batch.
        """
        flat_rewards = self.batch['context_rewards'].cpu().detach().numpy().flatten()
        best_step = np.argmax(flat_rewards)
        context_actions = self.batch['context_actions'].cpu().detach().numpy()
        self.a = context_actions[0][best_step]
        return self.a


class EmpMeanPolicy(Controller):
    """Greedy policy on each arm's empirical mean reward in the context.

    With ``online=True``, an arm that has never been pulled in the context
    (the least-pulled arm, when its count is zero) is played before exploiting.
    Unpulled arms otherwise have an empirical mean of 0.
    """

    def __init__(self, env, online=False, batch_size=1):
        super().__init__()
        self.env = env
        self.online = online
        self.batch_size = batch_size

    def reset(self):
        """Nothing to reset; statistics are recomputed from ``self.batch``."""
        return

    def act(self, x):
        """Return a one-hot arm for a single torch context (batch element 0)."""
        ctx_actions = self.batch['context_actions'].cpu().detach().numpy()[0]
        ctx_rewards = self.batch['context_rewards'].cpu().detach().numpy().flatten()

        totals = np.zeros(self.env.dim)
        pulls = np.zeros(self.env.dim)
        for step, action in enumerate(ctx_actions):
            arm = np.argmax(action)
            totals[arm] += ctx_rewards[step]
            pulls[arm] += 1

        means = totals / np.maximum(1, pulls)
        least_pulled = np.argmin(pulls)
        if self.online and pulls[least_pulled] == 0:
            chosen = least_pulled
        else:
            chosen = np.argmax(means)

        one_hot = np.zeros(self.env.dim)
        one_hot[chosen] = 1.0
        self.a = one_hot
        return self.a

    def act_numpy_vec(self, x):
        """Return one-hot arms of shape (batch_size, dim) from a numpy context batch."""
        arms_per_step = np.argmax(self.batch['context_actions'], axis=-1)
        ctx_rewards = self.batch['context_rewards']

        totals = np.zeros((self.batch_size, self.env.dim))
        pulls = np.zeros((self.batch_size, self.env.dim))
        for row in range(self.batch_size):
            for arm in range(self.env.dim):
                hits = ctx_rewards[row][arms_per_step[row] == arm]
                totals[row, arm] = np.sum(hits)
                pulls[row, arm] = len(hits)

        means = totals / np.maximum(1, pulls)
        chosen = np.argmax(means, axis=-1)
        least_pulled = np.argmin(pulls, axis=-1)
        if self.online:
            # Rows whose least-pulled arm was never tried explore that arm.
            unseen = pulls[np.arange(self.batch_size), least_pulled] == 0
            chosen[unseen] = least_pulled[unseen]

        one_hot = np.zeros((self.batch_size, self.env.dim))
        one_hot[np.arange(self.batch_size), chosen] = 1.0
        self.a = one_hot
        return self.a



class ThompsonSamplingPolicy(Controller):
    """Gaussian Thompson-sampling baseline over ``env.dim`` arms.

    Every arm has a Normal(``prior_mean``, ``prior_var``) prior on its mean
    reward. Observed rewards are modelled as Normal with known variance
    ``std**2``. The conjugate posterior is rebuilt from scratch from the
    context whenever a batch is set.

    Args:
        env: environment exposing ``dim`` (number of arms).
        std: assumed reward-noise standard deviation.
        sample: if True, act on a single posterior draw. Otherwise play the arm
            that is most often best across 100 posterior draws.
        prior_mean, prior_var: prior on each arm's mean reward.
        warm_start: used only by ``act`` with ``sample=True``. If some arm was
            never pulled in the context, play it.
        batch_size: number of parallel environments for the ``*_numpy_vec``
            methods.
    """

    def __init__(self, env, std=.1, sample=False, prior_mean=.5, prior_var=1/12.0, warm_start=False, batch_size=1):
        super().__init__()
        self.env = env
        self.variance = std**2
        self.prior_mean = prior_mean
        self.prior_variance = prior_var
        self.batch_size = batch_size

        self.reset()
        self.sample = sample
        self.warm_start = warm_start

    def reset(self):
        """Reset the posterior to the prior.

        The statistics are 1-D ``(dim,)`` when ``batch_size == 1`` (the layout
        ``act`` uses) and ``(batch_size, dim)`` otherwise.
        """
        if self.batch_size > 1:
            shape = (self.batch_size, self.env.dim)
        else:
            shape = (self.env.dim,)
        self.means = np.ones(shape) * self.prior_mean
        self.variances = np.ones(shape) * self.prior_variance
        self.counts = np.zeros(shape)

    def set_batch(self, batch):
        """Rebuild the posterior from a torch context batch (element 0 only)."""
        self.reset()
        self.batch = batch
        actions = self.batch['context_actions'].cpu().detach().numpy()[0]
        rewards = self.batch['context_rewards'].cpu().detach().numpy().flatten()
        pulled = np.argmax(actions, axis=1)

        for c in pulled:
            self.counts[c] += 1

        for c in range(self.env.dim):
            self.update_posterior(c, rewards[pulled == c])

    def set_batch_numpy_vec(self, batch):
        """Rebuild per-environment posteriors from a numpy context batch.

        ``context_actions`` is indexed as ``[env, step, arm]`` and
        ``context_rewards`` as ``[env, step, 0]``.
        """
        self.reset()
        self.batch = batch
        actions = self.batch['context_actions']
        rewards = self.batch['context_rewards'][:, :, 0]

        # reset() keeps 1-D statistics when batch_size == 1, but everything
        # below indexes them as [env, arm]. Promote them to 2-D so that
        # batch_size == 1 works instead of raising IndexError.
        self.means = np.atleast_2d(self.means)
        self.variances = np.atleast_2d(self.variances)
        self.counts = np.atleast_2d(self.counts)

        pulled = np.argmax(actions, axis=-1)  # [env, step] -> arm index
        rows = np.arange(self.batch_size)
        for step in range(pulled.shape[1]):
            self.counts[rows, pulled[:, step]] += 1

        arm_means = np.zeros((self.batch_size, self.env.dim))
        for idx in range(self.batch_size):
            for c in range(self.env.dim):
                if self.counts[idx, c] > 0:
                    arm_means[idx, c] = np.mean(rewards[idx][pulled[idx] == c])

        self.update_posterior_all(arm_means)

    def update_posterior(self, c, arm_rewards):
        """Conjugate Normal update of arm ``c`` from its rewards (no-op if unpulled)."""
        n = self.counts[c]

        if n > 0:
            arm_mean = np.mean(arm_rewards)
            # Shrink the sample mean towards the prior mean.
            prior_weight = self.variance / (self.variance + (n * self.prior_variance))
            new_mean = prior_weight * self.prior_mean + (1 - prior_weight) * arm_mean
            new_variance = 1 / (1 / self.prior_variance + n / self.variance)

            self.means[c] = new_mean
            self.variances[c] = new_variance

    def update_posterior_all(self, arm_means):
        """Vectorised ``update_posterior`` for every (env, arm) with pulls > 0."""
        prior_weight = self.variance / (self.variance + (self.counts * self.prior_variance))
        new_mean = prior_weight * self.prior_mean + (1 - prior_weight) * arm_means
        new_variance = 1 / (1 / self.prior_variance + self.counts / self.variance)

        # Unpulled arms keep the prior.
        mask = (self.counts > 0)
        self.means[mask] = new_mean[mask]
        self.variances[mask] = new_variance[mask]

    def act(self, x):
        """Return a one-hot arm of length ``env.dim`` for a single context."""
        if self.sample:
            values = np.random.normal(self.means, np.sqrt(self.variances))
            i = np.argmax(values)

            if self.warm_start:
                actions = self.batch['context_actions'].cpu().detach().numpy()[0]
                counts = np.bincount(np.argmax(actions, axis=1), minlength=self.env.dim)
                j = np.argmin(counts)
                if counts[j] == 0:
                    i = j
        else:
            # Play the arm that wins most often over 100 posterior draws.
            values = np.random.normal(self.means, np.sqrt(self.variances), size=(100, self.env.dim))
            amax = np.argmax(values, axis=1)
            freqs = np.bincount(amax, minlength=self.env.dim)
            i = np.argmax(freqs)

        a = np.zeros(self.env.dim)
        a[i] = 1.0
        self.a = a
        return self.a

    def act_numpy_vec(self, x):
        """Return one-hot arms of shape (batch_size, dim)."""
        # Work on a 2-D view so that batch_size == 1 also yields (1, dim).
        shape = (self.batch_size, self.env.dim)
        means = np.reshape(self.means, shape)
        stds = np.sqrt(np.reshape(self.variances, shape))

        if self.sample:
            values = np.random.normal(means, stds)
            action_indices = np.argmax(values, axis=-1)
        else:
            # (batch, 100 draws, arm) -> most frequent winning arm per env.
            values = np.stack([np.random.normal(means, stds) for _ in range(100)], axis=1)
            amax = np.argmax(values, axis=-1)
            freqs = np.array([np.bincount(am, minlength=self.env.dim) for am in amax])
            action_indices = np.argmax(freqs, axis=-1)

        actions = np.zeros(shape)
        actions[np.arange(self.batch_size), action_indices] = 1.0
        self.a = actions
        return self.a



class ThompsonSamplingPolicy_greedy(ThompsonSamplingPolicy):
    """Thompson sampling with an occasional purely greedy step.

    Behaves exactly like :class:`ThompsonSamplingPolicy` (prior, posterior
    updates, ``act``) except in ``act_numpy_vec``. There, with probability
    ``epsilon`` (one coin flip shared by the whole batch), every environment
    plays the arm with the highest posterior mean. Otherwise the usual
    Thompson-sampling action is taken.

    Args:
        epsilon: probability of the greedy step (default 0.1).
        All other arguments are as in :class:`ThompsonSamplingPolicy`.
    """

    def __init__(self, env, std=.1, sample=False, prior_mean=.5, prior_var=1/12.0, warm_start=False, batch_size=1, epsilon=0.1):
        super().__init__(env, std=std, sample=sample, prior_mean=prior_mean,
                         prior_var=prior_var, warm_start=warm_start,
                         batch_size=batch_size)
        self.epsilon = epsilon

    def act_numpy_vec(self, x):
        """Return one-hot arms of shape (batch_size, dim), greedy with prob. ``epsilon``."""
        if np.random.random() >= self.epsilon:
            # Regular Thompson-sampling step (same RNG use as before).
            return super().act_numpy_vec(x)

        shape = (self.batch_size, self.env.dim)
        greedy = np.argmax(np.reshape(self.means, shape), axis=-1)
        actions = np.zeros(shape)
        actions[np.arange(self.batch_size), greedy] = 1.0
        self.a = actions
        return self.a



class PessMeanPolicy(Controller):
    """Pessimistic baseline: play the arm with the best lower confidence bound.

    An arm's score is its empirical mean context reward minus
    ``const / max(1, sqrt(pulls))``. An arm that was never pulled scores
    ``-const``.
    """

    def __init__(self, env, const=1.0, batch_size=1):
        super().__init__()
        self.env = env
        self.const = const
        self.batch_size = batch_size

    def reset(self):
        """Stateless between batches; nothing to clear."""
        return

    def _lower_bounds(self, sums, pulls):
        """Return empirical mean minus the count-based penalty, element-wise."""
        empirical = sums / np.maximum(1, pulls)
        # sqrt(pulls) is floored at 1, so an unpulled arm's penalty is exactly `const`.
        penalty = self.const / np.maximum(1, np.sqrt(pulls))
        return empirical - penalty

    def act(self, x):
        """Return a one-hot arm for a single torch context (batch element 0)."""
        ctx_actions = self.batch['context_actions'].cpu().detach().numpy()[0]
        ctx_rewards = self.batch['context_rewards'].cpu().detach().numpy().flatten()

        sums = np.zeros(self.env.dim)
        pulls = np.zeros(self.env.dim)
        for step in range(len(ctx_actions)):
            arm = np.argmax(ctx_actions[step])
            sums[arm] += ctx_rewards[step]
            pulls[arm] += 1

        best = np.argmax(self._lower_bounds(sums, pulls))
        choice = np.zeros(self.env.dim)
        choice[best] = 1.0
        self.a = choice
        return self.a

    def act_numpy_vec(self, x):
        """Return one-hot arms of shape (batch_size, dim) from a numpy context batch."""
        pulled = np.argmax(self.batch['context_actions'], axis=-1)
        ctx_rewards = self.batch['context_rewards']

        sums = np.zeros((self.batch_size, self.env.dim))
        pulls = np.zeros((self.batch_size, self.env.dim))
        for env_idx in range(self.batch_size):
            for arm in range(self.env.dim):
                hits = ctx_rewards[env_idx][pulled[env_idx] == arm]
                sums[env_idx, arm] = np.sum(hits)
                pulls[env_idx, arm] = len(hits)

        best = np.argmax(self._lower_bounds(sums, pulls), axis=-1)
        choice = np.zeros((self.batch_size, self.env.dim))
        choice[np.arange(self.batch_size), best] = 1.0
        self.a = choice
        return self.a



class UCBPolicy(Controller):
    """Optimistic baseline: play the arm with the best upper confidence bound.

    An arm's score is its empirical mean context reward plus
    ``const / max(1, sqrt(visits))``. An arm that was never pulled scores
    ``const``.
    """

    def __init__(self, env, const=1.0, batch_size=1):
        super().__init__()
        self.env = env
        self.const = const
        self.batch_size = batch_size

    def reset(self):
        """Stateless between batches; nothing to clear."""
        return

    def act(self, x):
        """Return a one-hot arm for a single torch context (batch element 0)."""
        actions = self.batch['context_actions'].cpu().detach().numpy()[0]
        rewards = self.batch['context_rewards'].cpu().detach().numpy().flatten()

        arm_of_step = [np.argmax(action) for action in actions]
        totals = np.zeros(self.env.dim)
        visits = np.zeros(self.env.dim)
        for t, k in enumerate(arm_of_step):
            totals[k] += rewards[t]
            visits[k] += 1

        # Exploration bonus; sqrt(visits) is clipped to at least one.
        upper = (totals / np.maximum(1, visits)
                 + self.const / np.maximum(1, np.sqrt(visits)))

        pick = np.zeros(self.env.dim)
        pick[np.argmax(upper)] = 1.0
        self.a = pick
        return self.a

    def act_numpy_vec(self, x):
        """Return one-hot arms of shape (batch_size, dim) from a numpy context batch."""
        step_arms = np.argmax(self.batch['context_actions'], axis=-1)
        reward_rows = self.batch['context_rewards']

        totals = np.zeros((self.batch_size, self.env.dim))
        visits = np.zeros((self.batch_size, self.env.dim))
        for row, (arms_row, rewards_row) in enumerate(zip(step_arms, reward_rows)):
            if row >= self.batch_size:
                break
            for k in range(self.env.dim):
                selected = rewards_row[arms_row == k]
                totals[row, k] = np.sum(selected)
                visits[row, k] = len(selected)

        upper = (totals / np.maximum(1, visits)
                 + self.const / np.maximum(1, np.sqrt(visits)))
        best = np.argmax(upper, axis=-1)

        pick = np.zeros((self.batch_size, self.env.dim))
        pick[np.arange(self.batch_size), best] = 1.0
        self.a = pick
        return self.a




class UnifPolicy(Controller):
    """Context-free baseline that plays an arm chosen uniformly at random.

    ``const`` is accepted only so that the signature matches the other
    baselines. It is not used.
    """

    def __init__(self, env, const=1.0, batch_size=1):
        super().__init__()
        self.env = env
        self.const = const
        self.batch_size = batch_size

    def reset(self):
        return

    def act(self, x):
        """Return a one-hot vector of length ``env.dim`` for a uniformly random arm."""
        # NOTE(review): draws from range(env.K) but sizes the vector by env.dim,
        # while act_numpy_vec uses env.dim for both; confirm K == dim for all envs.
        i = np.random.choice(self.env.K)
        a = np.zeros(self.env.dim)
        a[i] = 1.0
        self.a = a
        return self.a

    def act_numpy_vec(self, x):
        """Return one-hot arms of shape (batch_size, dim), drawn independently and uniformly.

        The previous implementation took the argmax of ``dim`` random integers
        per row. That is biased toward low arm indices, because argmax breaks
        ties to the first index. Sampling the arm index directly is uniform.
        """
        arms = np.random.choice(self.env.dim, size=self.batch_size)
        a = np.zeros((self.batch_size, self.env.dim))
        a[np.arange(self.batch_size), arms] = 1.0
        self.a = a
        return self.a



class BanditTransformerController(Controller):
    def __init__(self, model, sample=False, batch_size=1, tf_type="original"):
        """Wrap a trained transformer ``model`` as a bandit controller.

        Args:
            model: callable on a batch dict; must expose ``config['action_dim']``,
                ``config['state_dim']`` and ``horizon``.
            sample: if True, ``act`` samples the arm from a softmax of the model
                output instead of taking its argmax.
            batch_size: number of parallel environments queried per call.
            tf_type: model output layout to decode: ``"original"``, ``"new"``
                or ``"new_opt_a"``.
        """
        # Initialise base Controller state (self.name). Previously this was
        # skipped, so get_name() raised AttributeError until set_name() was called.
        super().__init__()
        self.model = model
        self.du = model.config['action_dim']
        self.dx = model.config['state_dim']
        self.H = model.horizon
        self.sample = sample
        self.batch_size = batch_size
        # Attached to every batch as batch['zeros'] before the model is called.
        self.zeros = torch.zeros(batch_size, self.dx**2 + self.du + 1).float().to(device)
        self.tf_type = tf_type

        # self.emp_means = np.zeros((batch_size, self.du))

    def set_env(self, env):
        """Deliberately ignore ``env``.

        Unlike the base class, this never stores ``self.env``. The visible
        ``act`` works only from ``self.batch`` and the query state.
        """
        return None

    def set_batch_numpy_vec(self, batch):
        """Convert every entry of ``batch`` to a float tensor on ``device``, then set it."""
        self.set_batch({
            key: torch.tensor(value).float().to(device)
            for key, value in batch.items()
        })

    def act(self, x):
        """Choose one arm for a single (non-vectorised) query state.

        Attaches ``zeros`` and ``query_states`` to ``self.batch``, runs the model,
        and decodes its output according to ``self.tf_type``. It then either
        samples an arm from the softmax of the scores (``self.sample`` true) or
        takes their argmax.

        Args:
            x: query state; wrapped as ``torch.tensor(x)[None, :]``, so it is
                presumably 1-D (TODO confirm against callers).

        Returns:
            One-hot ``np.ndarray`` of length ``self.du``. Unlike the baseline
            policies, this does not store the result in ``self.a``.

        Raises:
            NotImplementedError: if ``self.tf_type`` is not ``"original"``,
                ``"new"`` or ``"new_opt_a"``.
        """
        self.batch['zeros'] = self.zeros

        # Add a leading batch axis of size 1 for the model.
        states = torch.tensor(x)[None, :].float().to(device)
        self.batch['query_states'] = states

        if self.tf_type == "original":
            a = self.model(self.batch)
        elif self.tf_type == "new":
            # Two-output model. With probability 0.5 decode the first output by
            # sampling, otherwise the second by argmax. This overwrites
            # self.sample, so the flag keeps the last coin flip after the call.
            a_next, r_pred = self.model(self.batch)
            epsilon = 0.5
            if np.random.random() < epsilon:
                a = a_next
                self.sample = True
            else:
                a = r_pred
                self.sample = False
        elif self.tf_type == "new_opt_a":
            # Three-output model; only the first output is used here.
            a, _, _ = self.model(self.batch)
        else:
            raise NotImplementedError

        # a = self.model(self.batch)
        # Drop the batch axis: scores for the single query.
        a = a.cpu().detach().numpy()[0]

        if self.sample:
            probs = scipy.special.softmax(a)
            # print("probs", probs)
            i = np.random.choice(np.arange(self.du), p=probs)
        else:
            i = np.argmax(a)

        a = np.zeros(self.du)
        a[i] = 1.0
        return a

    def act_numpy_vec(self, x):
        self.batch['zeros'] = self.zeros

        states = torch.tensor(np.array(x))
        if self.batch_size == 1:
            states = states[None,:]
        states = states.float().to(device)
        self.batch['query_states'] = states

        
        # a = self.model(self.batch)

        if self.tf_type == "original":
            a = self.model(self.batch)
            # ipdb.set_trace()
            self.sample = True
        elif self.tf_type == "new":

            
            a_next, r_pred = self.model(self.batch)

            # ipdb.set_trace()

            # # Predict reward of next action, then compare with predicted reward of bext reward action
            # a = torch.zeros_like(a_next)
            # r_pred_ = r_pred.cpu().detach().numpy()
            # a_next_ = a_next.cpu().detach().numpy()

            # size_ = a_next.shape[0]
            # for i in range(size_):
            #     next_action = np.argmax(a_next_[i], axis=-1)
            #     if r_pred_[i, next_action] < 0.01:
            #         # a[i] = r_pred[i]
            #         a[i] = r_pred[i]
            #         self.sample = False
            #     else:
            #         a[i] = a_next[i]
            #         self.sample = True
            


            ########################### Run this files with single model runs from bash script ##########
            # ##### Uncomment to save the predicted reward, check file name carefully-> without head/ with head #####
            # with open("script/run_without_head/r_pred_without_head_pred.pkl", "ab") as f:
            #     r_pred_ = r_pred.cpu().detach().numpy()
            #     pickle.dump(r_pred_, f)
            # f.close()

            # ### Uncomment to save the predicted reward, check file name carefully-> without head/ with head #####
            # with open("script/online/data/r_pred_without_head_pred.pkl", "ab+") as f:
            #     r_pred_ = r_pred.cpu().detach().numpy()
            #     pickle.dump(r_pred_, f)
            # f.close()


            # ### Uncomment to save the predicted reward, check file name carefully-> without head/ with head #####
            # with open("script/online/data_larger/r_pred_without_head_pred.pkl", "ab+") as f:
            #     r_pred_ = r_pred.cpu().detach().numpy()
            #     pickle.dump(r_pred_, f)
            # f.close()

            # ### Uncomment to save the predicted reward, check file name carefully-> without head/ with head #####
            # with open("script/online/data_large/r_pred_without_head_pred.pkl", "ab+") as f:
            #     r_pred_ = r_pred.cpu().detach().numpy()
            #     pickle.dump(r_pred_, f)
            # f.close()

            # ## Uncomment to save the predicted reward, check file name carefully-, turn sample = True in eval linear bandit #####
            # with open("script/online/data/r_pred_without_head_pred_tau.pkl", "ab+") as f:
            #     r_pred_ = r_pred.cpu().detach().numpy()
            #     pickle.dump(r_pred_, f)
            # f.close()

            # ## Uncomment to save the predicted reward, check file name carefully-, turn sample = True in eval linear bandit #####
            # with open("script/online/data_larger/r_pred_without_head_pred_tau.pkl", "ab+") as f:
            #     r_pred_ = r_pred.cpu().detach().numpy()
            #     pickle.dump(r_pred_, f)
            # f.close()


            # #### Uncomment to save the predicted reward, check file name carefully-, turn sample = True in eval linear bandit #####
            # with open("script/online/data_large/r_pred_without_head_pred.pkl", "ab+") as f:
            #     r_pred_ = r_pred.cpu().detach().numpy()
            #     pickle.dump(r_pred_, f)
            # f.close()

            # ## Uncomment to save the predicted reward, check file name carefully-, turn sample = True in eval linear bandit #####
            # with open("script/online/data_large/r_pred_without_head_pred_tau.pkl", "ab+") as f:
            #     r_pred_ = r_pred.cpu().detach().numpy()
            #     pickle.dump(r_pred_, f)
            # f.close()




            # #### Uncomment to save the predicted reward, check file name carefully-, turn sample = True in eval linear bandit #####
            # with open("script/new_arms/data/r_pred_without_head_pred.pkl", "ab+") as f:
            #     r_pred_ = r_pred.cpu().detach().numpy()
            #     pickle.dump(r_pred_, f)
            # f.close()

            # ## Uncomment to save the predicted reward, check file name carefully-, turn sample = True in eval linear bandit #####
            # with open("script/new_arms/data/r_pred_without_head_pred_tau.pkl", "ab+") as f:
            #     r_pred_ = r_pred.cpu().detach().numpy()
            #     pickle.dump(r_pred_, f)
            # f.close()



            # #### Uncomment to save the predicted reward, check file name carefully-, turn sample = True in eval linear bandit #####
            # with open("script/new_arms/data_large/r_pred_without_head_pred.pkl", "ab+") as f:
            #     r_pred_ = r_pred.cpu().detach().numpy()
            #     pickle.dump(r_pred_, f)
            # f.close()

            # ## Uncomment to save the predicted reward, check file name carefully-, turn sample = True in eval linear bandit #####
            # with open("script/new_arms/data_large/r_pred_without_head_pred_tau.pkl", "ab+") as f:
            #     r_pred_ = r_pred.cpu().detach().numpy()
            #     pickle.dump(r_pred_, f)
            # f.close()




            # #### Uncomment to save the predicted reward, check file name carefully-, turn sample = True in eval linear bandit #####
            # with open("script/new_arms_final/data/r_pred_without_head_pred.pkl", "ab+") as f:
            #     r_pred_ = r_pred.cpu().detach().numpy()
            #     pickle.dump(r_pred_, f)
            # f.close()

            # ## Uncomment to save the predicted reward, check file name carefully-, turn sample = True in eval linear bandit #####
            # with open("script/new_arms_final/data/r_pred_without_head_pred_tau.pkl", "ab+") as f:
            #     r_pred_ = r_pred.cpu().detach().numpy()
            #     pickle.dump(r_pred_, f)
            # f.close()


            # #### Uncomment to save the predicted reward, check file name carefully-, turn sample = True in eval linear bandit #####
            # with open("script/new_arms_final/data_large/r_pred_without_head_pred.pkl", "ab+") as f:
            #     r_pred_ = r_pred.cpu().detach().numpy()
            #     pickle.dump(r_pred_, f)
            # f.close()

            # ## Uncomment to save the predicted reward, check file name carefully-, turn sample = True in eval linear bandit #####
            # with open("script/new_arms_final/data_large/r_pred_without_head_pred_tau.pkl", "ab+") as f:
            #     r_pred_ = r_pred.cpu().detach().numpy()
            #     pickle.dump(r_pred_, f)
            # f.close()
            
            epsilon = 0.0
            if np.random.random() < epsilon:
                a = a_next
                # self.sample = True
            else:
                a = r_pred
                # self.sample = False

                ##### Uncomment to explore via reward distribution #############
                # a = torch.exp(r_pred/0.25) 
                # self.sample = True
                
            
        elif self.tf_type == "new_opt_a":
            a_next, r_pred, a_opt = self.model(self.batch)
            
            a = torch.zeros_like(a_next)
            a_opt_ = a_opt.cpu().detach().numpy()
            r_pred_ = r_pred.cpu().detach().numpy()
            
            size_ = a_opt_.shape[0]
            for i in range(size_):
                if np.argmax(a_opt_[i], axis=-1) == np.argmax(r_pred_[i], axis=-1):
                    a[i] = a_opt[i]
                    self.sample = False
                else:
                    a[i] = a_next[i]
                    self.sample = False
            
            
            # ipdb.set_trace()
            # epsilon = 0.1
            # if np.random.random() < epsilon:
            #     a = a_next
            #     self.sample = False
            # else:
            #     a = r_pred
            #     self.sample = False
        else:
            raise NotImplementedError
            
        a = a.cpu().detach().numpy()
        if self.batch_size == 1:
            a = a[0]

        if self.sample:
            # temperature = 0.05 # linear
            # temperature = 0.001 # non-linear
            # temperature = 0.05 # new arms linear
            # temperature = 0.005 # new arms non-linear
            temperature = 0.05 # default
            # temperature = 0.08
            probs = scipy.special.softmax(a/temperature, axis=-1)
            # print("probs", probs)
            action_indices = np.array([np.random.choice(np.arange(self.du), p=p) for p in probs])
        else:
            action_indices = np.argmax(a, axis=-1)

        actions = np.zeros((self.batch_size, self.du))
        actions[np.arange(self.batch_size), action_indices] = 1.0

        ###################### Run below for multiple models runs from bash script #################
        # # #### Uncomment to save the predicted reward, and actions, check file name carefully->  #####
        # with open("script/online/data/a_pred.pkl", "ab+") as f:
        #     pickle.dump(actions, f)
        # f.close()

        #### Uncomment to save the predicted reward, and actions, check file name carefully->  #####
        # with open("script/online/data_larger/a_pred.pkl", "ab+") as f:
        #     pickle.dump(actions, f)
        # f.close()

        # ### Uncomment to save the predicted reward, and actions, check file name carefully->  #####
        # with open("script/online/data_large/a_pred.pkl", "ab+") as f:
        #     pickle.dump(actions, f)
        # f.close()


        # ### Uncomment to save the predicted reward for new arms linear small, and actions, check file name carefully->  #####
        # with open("script/new_arms/data/a_pred.pkl", "ab+") as f:
        #     pickle.dump(actions, f)
        # f.close()


        # ### Uncomment to save the predicted reward for new arms linear large, and actions, check file name carefully->  #####
        # with open("script/new_arms/data_large/a_pred.pkl", "ab+") as f:
        #     pickle.dump(actions, f)
        # f.close()


        # ### Uncomment to save the predicted reward for new arms linear large, and actions, check file name carefully->  #####
        # with open("script/new_arms_final/data/a_pred.pkl", "ab+") as f:
        #     pickle.dump(actions, f)
        # f.close()

        # ### Uncomment to save the predicted reward for new arms linear large, and actions, check file name carefully->  #####
        # with open("script/new_arms_final/data_large/a_pred.pkl", "ab+") as f:
        #     pickle.dump(actions, f)
        # f.close()
        
        return actions


def softmax(arr, axis=None):
    """Return the softmax of ``arr``.

    Args:
        arr: Array-like of logits.
        axis: Axis along which to normalize. ``None`` (the default) normalizes
            over all elements of ``arr``.

    Returns:
        ``np.ndarray`` with the same shape as ``arr`` whose entries along
        ``axis`` are non-negative and sum to 1.
    """
    arr = np.asarray(arr)
    # Subtract the maximum value for numerical stability; ``keepdims`` lets the
    # reductions broadcast back against ``arr`` for any ``axis``.
    exp_arr = np.exp(arr - np.max(arr, axis=axis, keepdims=True))
    return exp_arr / exp_arr.sum(axis=axis, keepdims=True)


class LinUCBPolicy(OptPolicy):
    """LinUCB baseline over the environment's fixed arm feature vectors.

    The reward parameter is estimated by ridge regression,
    ``theta = (init_cov + X^T X)^{-1} X^T r``, where ``X`` stacks the feature
    rows (from ``env.arms``) of the arms played so far and ``r`` their rewards.
    Each arm is scored by ``theta @ arm + const * sqrt(arm @ cov_inv @ arm)``.

    Args:
        env: Environment exposing ``arms`` (one feature row per arm) and
            ``dim`` (number of arms / width of the one-hot actions).
        const: Weight of the exploration bonus in the UCB score.
        batch_size: Number of environments handled by ``act_numpy_vec``.
        soft: If True (default), ``act_numpy_vec`` samples its action from
            ``softmax(means / temperature)`` over the estimated means, so the
            UCB bonus is not used there. If False, it takes the arm with the
            highest UCB score.
        temperature: Softmax temperature used when ``soft`` is True.
        dump_path: Optional pickle file. When set, every ``act_numpy_vec`` call
            appends the (batch_size, dim) array of estimated means to it.
    """

    def __init__(self, env, const=1.0, batch_size=1, soft=True, temperature=0.05,
                 dump_path=None):
        super().__init__(env)
        self.rand = True
        self.const = const
        self.arms = env.arms
        self.d = self.arms.shape[1]
        self.dim = env.dim
        self.theta = np.zeros(self.d)
        self.init_cov = 1.0 * np.eye(self.d)
        self.batch_size = batch_size
        self.soft = soft
        self.temperature = temperature
        self.dump_path = dump_path

    def _fit(self, actions, rewards):
        """Ridge-regress rewards onto the feature rows of the played arms.

        Args:
            actions: (H, dim) one-hot actions; the arm is the argmax of each row.
            rewards: H rewards aligned with ``actions`` (any shape with H elements).

        Returns:
            ``(theta, cov_inv)``: the (d,) estimate and ``inv(init_cov + X^T X)``.
        """
        played = self.arms[np.argmax(actions, axis=1)]
        cov_inv = np.linalg.inv(self.init_cov + played.T @ played)
        theta = cov_inv @ played.T @ np.asarray(rewards).reshape(-1)
        return theta, cov_inv

    def _scores(self, theta, cov_inv):
        """Return ``(means, ucb)`` for every arm under the fitted model."""
        means = self.arms @ theta
        # Row-wise quadratic form arm @ cov_inv @ arm for all arms at once.
        bonus = np.sqrt(np.sum((self.arms @ cov_inv) * self.arms, axis=1))
        return means, means + self.const * bonus

    def _one_hot(self, index):
        """Return a length-``dim`` vector with a 1 at ``index``."""
        hot_vector = np.zeros(self.dim)
        hot_vector[index] = 1
        return hot_vector

    def act(self, x):
        """Choose an action for a single environment from ``self.batch``.

        Returns a uniformly random one-hot action when no rewards have been
        observed yet. Otherwise it returns the one-hot UCB-maximizing arm, which
        is also stored in ``self.a``. ``x`` is unused.
        """
        if len(self.batch['rollin_rs'][0]) < 1:
            return self._one_hot(np.random.choice(np.arange(self.dim)))

        actions = self.batch['rollin_us'].cpu().detach().numpy()[0]
        rewards = self.batch['rollin_rs'].cpu().detach().numpy().flatten()
        theta, cov_inv = self._fit(actions, rewards)
        _, ucb = self._scores(theta, cov_inv)

        self.a = self._one_hot(np.argmax(ucb))
        return self.a

    def act_numpy_vec(self, x):
        """Choose one action per environment from ``self.batch``.

        Reads ``self.batch['context_actions']`` / ``['context_rewards']``. The
        per-environment entries are (H, dim) one-hot actions and H rewards.
        ``x`` is unused.

        Returns:
            (batch_size, dim) array of one-hot actions. These are uniformly
            random when the first environment has no rewards yet.
        """
        # TODO: parallelize this later
        actions_batch = self.batch['context_actions']
        rewards_batch = self.batch['context_rewards']

        if len(rewards_batch[0]) < 1:
            indices = np.random.choice(np.arange(self.dim), size=self.batch_size)
            hot_vectors = np.zeros((self.batch_size, self.dim))
            hot_vectors[np.arange(self.batch_size), indices] = 1
            return hot_vectors

        hot_vectors = np.zeros((self.batch_size, self.dim))
        value_means = np.zeros((self.batch_size, self.dim))  # estimated mean reward per arm
        for m in range(self.batch_size):
            theta, cov_inv = self._fit(actions_batch[m], rewards_batch[m])
            means, ucb = self._scores(theta, cov_inv)
            value_means[m] = means

            if self.soft:
                # Soft LinUCB: sample from a tempered softmax over the estimated
                # means; the UCB exploration bonus is not used in this mode.
                best_arm_index = np.random.choice(
                    np.arange(self.dim), p=softmax(means / self.temperature)
                )
            else:
                best_arm_index = np.argmax(ucb)
            hot_vectors[m, best_arm_index] = 1

        if self.dump_path is not None:
            # Appends one pickled (batch_size, dim) array per call for offline analysis.
            with open(self.dump_path, "ab") as f:
                pickle.dump(value_means, f)

        return hot_vectors



################# LinUCB for bilinear #################

class LinUCBPolicy_Bilinear(OptPolicy):
    """LinUCB baseline for bilinear rewards ``r = left_i^T Theta right_i``.

    Arm ``i`` is represented by the row-major vectorization of
    ``outer(env.arms_left[i], env.arms_right[i])``, so the reward is linear in
    ``vec(Theta)``.

    * ``act`` runs plain LinUCB on these features.
    * ``act_numpy_vec`` runs a two-stage, subspace-aware variant in the spirit
      of ESTR / LowOFUL (Jun et al., 2019):
      1. Fit a ridge estimate of ``Theta`` and factor it by SVD.
      2. Rotate the left and right arm features into the singular bases.
      3. Run LinUCB on the rotated features. The ridge penalty is ``lambda_``
         on coordinates inside the estimated rank-``r`` subspace and
         ``lambda_perp`` on the complement block.

    Args:
        env: Environment exposing ``arms_left``, ``arms_right`` (one feature
            row per arm on each side) and ``dim`` (number of arms).
        const: Weight of the exploration bonus in the UCB score.
        batch_size: Number of environments handled by ``act_numpy_vec``.
        rank: Subspace rank ``r``. ``None`` (default) uses
            ``np.linalg.matrix_rank`` of each stage-1 estimate.
            NOTE(review): a noisy estimate is usually full rank, in which case
            every coordinate gets ``lambda_``. Consider passing the true rank.
        lambda_: Penalty on in-subspace coordinates.
        lambda_perp: Penalty on complement coordinates. ``None`` uses
            ``21 / log(1 + 21 / lambda_)``.
        dump_path: Optional pickle file. When set, every ``act_numpy_vec`` call
            appends the (batch_size, dim) array of estimated means to it.
    """

    def __init__(self, env, const=1.0, batch_size=1, rank=None, lambda_=0.05,
                 lambda_perp=None, dump_path=None):
        super().__init__(env)

        print("LinUCBPolicy_Bilinear")
        self.rand = True
        self.const = const
        self.arm_left = env.arms_left
        self.arm_right = env.arms_right
        self.dim = env.dim

        # Keep Theta's shape so flat estimates can be reshaped back to
        # (dim_left, dim_right), even when the two sides differ in size.
        self.dim_left = np.size(self.arm_left[0])
        self.dim_right = np.size(self.arm_right[0])
        self.arms = self._vectorize(self.arm_left, self.arm_right)

        self.d = self.arms.shape[1]
        self.theta = np.zeros(self.d)
        self.init_cov = 1.0 * np.eye(self.d)
        self.batch_size = batch_size

        self.rank = rank
        self.lambda_ = lambda_
        if lambda_perp is None:
            # NOTE(review): 21 is carried over unchanged. It presumably plays the
            # role of the horizon T in lambda_perp = T / log(1 + T / lambda) — confirm.
            lambda_perp = 21 / np.log(1 + (21 / lambda_))
        self.lambda_perp = lambda_perp
        self.dump_path = dump_path

    def _vectorize(self, left, right):
        """Return the (dim, dim_left * dim_right) matrix of vec(outer(left[i], right[i]))."""
        return np.array(
            [np.outer(left[i], right[i]) for i in range(self.dim)]
        ).reshape(self.dim, -1)

    @staticmethod
    def _fit(features, rewards, reg):
        """Ridge regression with penalty matrix ``reg``.

        Returns:
            ``(theta, cov_inv)`` with ``cov_inv = inv(reg + X^T X)`` and
            ``theta`` flat of length ``features.shape[1]``.
        """
        cov_inv = np.linalg.inv(reg + features.T @ features)
        theta = cov_inv @ features.T @ np.asarray(rewards).reshape(-1)
        return theta, cov_inv

    def _scores(self, arms, theta, cov_inv):
        """Return ``(means, ucb)`` for every row of ``arms``."""
        means = arms @ theta
        # Row-wise quadratic form arm @ cov_inv @ arm for all arms at once.
        bonus = np.sqrt(np.sum((arms @ cov_inv) * arms, axis=1))
        return means, means + self.const * bonus

    def _subspace_penalty(self, rank):
        """Diagonal LowOFUL penalty in the rotated, row-major vectorized basis.

        After rotation, entry (i, j) of the rotated parameter belongs to the
        estimated rank-``rank`` subspace iff ``i < rank`` or ``j < rank``.
        Those ``(dim_left + dim_right) * rank - rank**2`` coordinates get
        ``lambda_``; the remaining complement block gets ``lambda_perp``.
        The result always has size ``dim_left * dim_right``.
        """
        in_subspace = np.zeros((self.dim_left, self.dim_right), dtype=bool)
        in_subspace[:rank, :] = True
        in_subspace[:, :rank] = True
        return np.diag(np.where(in_subspace.ravel(), self.lambda_, self.lambda_perp))

    def act(self, x):
        """Plain LinUCB for a single environment from ``self.batch``.

        Returns a uniformly random one-hot action when no rewards have been
        observed yet. Otherwise it returns the one-hot UCB-maximizing arm, which
        is also stored in ``self.a``. ``x`` is unused.
        """
        if len(self.batch['rollin_rs'][0]) < 1:
            i = np.random.choice(np.arange(self.dim))
            hot_vector = np.zeros(self.dim)
            hot_vector[i] = 1
            return hot_vector

        actions = self.batch['rollin_us'].cpu().detach().numpy()[0]
        rewards = self.batch['rollin_rs'].cpu().detach().numpy().flatten()
        actions_indices = np.argmax(actions, axis=1)

        theta, cov_inv = self._fit(self.arms[actions_indices], rewards, self.init_cov)
        _, ucb = self._scores(self.arms, theta, cov_inv)

        hot_vector = np.zeros(self.dim)
        hot_vector[np.argmax(ucb)] = 1
        self.a = hot_vector
        return hot_vector

    def act_numpy_vec(self, x):
        """Choose one action per environment with the two-stage subspace LinUCB.

        Reads ``self.batch['context_actions']`` / ``['context_rewards']``. The
        per-environment entries are (H, dim) one-hot actions and H rewards.
        ``x`` is unused.

        Returns:
            (batch_size, dim) array of one-hot actions. These are uniformly
            random when the first environment has no rewards yet.
        """
        # TODO: parallelize this later
        actions_batch = self.batch['context_actions']
        rewards_batch = self.batch['context_rewards']

        if len(rewards_batch[0]) < 1:
            indices = np.random.choice(np.arange(self.dim), size=self.batch_size)
            hot_vectors = np.zeros((self.batch_size, self.dim))
            hot_vectors[np.arange(self.batch_size), indices] = 1
            return hot_vectors

        hot_vectors = np.zeros((self.batch_size, self.dim))
        value_means = np.zeros((self.batch_size, self.dim))  # estimated mean reward per arm
        for m in range(self.batch_size):
            actions_indices = np.argmax(actions_batch[m], axis=1)
            rewards = rewards_batch[m]

            # Stage 1: ridge estimate of Theta on the raw bilinear features.
            theta, _ = self._fit(self.arms[actions_indices], rewards, self.init_cov)
            theta_matrix = theta.reshape(self.dim_left, self.dim_right)
            U, _, VT = np.linalg.svd(theta_matrix)
            rank = self.rank if self.rank is not None else np.linalg.matrix_rank(theta_matrix)

            # Stage 2: LinUCB on features rotated into the singular bases, with
            # a larger penalty outside the estimated low-rank subspace.
            arms = self._vectorize(self.arm_left @ U, self.arm_right @ VT.T)
            reg = self.init_cov + self._subspace_penalty(rank)
            theta, cov_inv = self._fit(arms[actions_indices], rewards, reg)
            means, ucb = self._scores(arms, theta, cov_inv)

            value_means[m] = means
            hot_vectors[m, np.argmax(ucb)] = 1

        if self.dump_path is not None:
            # Appends one pickled (batch_size, dim) array per call for offline analysis.
            with open(self.dump_path, "ab") as f:
                pickle.dump(value_means, f)

        return hot_vectors






class BayesPredictorPolicy_test(OptPolicy):
    """Per-trajectory "Bayes predictor" baseline evaluated on ``train_dataset``.

    For every trajectory in ``train_dataset``, ``act_numpy_vec`` scores the
    arms as ``(S + D)^{-1} R`` and returns the best arm, where:
    * ``D`` is the diagonal matrix of per-arm pull counts;
    * ``R`` holds the per-arm reward sums;
    * ``S = arms arms^T / d + 0.001 I``.

    ``act`` is the single-environment LinUCB rule.
    """

    def __init__(self, env, const=1.0, train_dataset = None, theta = None, batch_size=1):
        super().__init__(env)
        self.rand = True
        self.const = const
        self.arms = env.arms
        self.d = self.arms.shape[1]
        self.dim = env.dim
        self.init_cov = 1.0 * np.eye(self.d)
        self.batch_size = batch_size

        self.train_dataset = train_dataset
        self.theta = theta

    def act(self, x):
        """LinUCB action for a single environment from ``self.batch`` (``x`` unused)."""
        if len(self.batch['rollin_rs'][0]) < 1:
            i = np.random.choice(np.arange(self.dim))
            hot_vector = np.zeros(self.dim)
            hot_vector[i] = 1
            return hot_vector

        else:
            actions = self.batch['rollin_us'].cpu().detach().numpy()[0]
            rewards = self.batch['rollin_rs'].cpu().detach().numpy().flatten()

            actions_indices = np.argmax(actions, axis=1)
            actions_arms = self.arms[actions_indices]

            cov = self.init_cov + actions_arms.T @ actions_arms
            cov_inv = np.linalg.inv(cov)

            theta = cov_inv @ actions_arms.T @ rewards

            best_arm_index = None
            best_value = -np.inf
            for i, arm in enumerate(self.arms):
                value = theta @ arm + self.const * np.sqrt(arm @ cov_inv @ arm)
                if value > best_value:
                    best_value = value
                    best_arm_index = i

            hot_vector = np.zeros(self.dim)
            hot_vector[best_arm_index] = 1
            self.a = hot_vector
            return hot_vector

    def _arm_statistics(self, actions, rewards):
        """Return per-arm pull counts and reward sums for one trajectory.

        Args:
            actions: (H, dim) one-hot actions; the arm is the argmax of each row.
            rewards: H rewards aligned with ``actions`` (any shape with H elements).

        Returns:
            ``(counts, sums)``, two float arrays of length ``dim``.
        """
        chosen = np.argmax(actions, axis=1)
        counts = np.bincount(chosen, minlength=self.dim).astype(float)
        sums = np.bincount(
            chosen,
            weights=np.asarray(rewards, dtype=float).reshape(-1),
            minlength=self.dim,
        )
        return counts, sums

    def act_numpy_vec(self, x):
        """Pick the highest-scoring arm for every trajectory in ``train_dataset``.

        ``x`` and ``self.batch`` are not used. Decisions come entirely from
        ``self.train_dataset['context_actions']`` / ``['context_rewards']``.

        Returns:
            (num_trajectories, dim) one-hot array. If a trajectory with no
            rewards is encountered, the method instead returns a
            (batch_size, dim) array of uniformly random one-hot actions.
        """
        # TODO: parallelize this later
        context_actions = self.train_dataset['context_actions']
        context_rewards = self.train_dataset['context_rewards']

        # Prior term built from the arm features scaled by 1/sqrt(d). It is
        # computed once and kept local, so ``self.arms`` (also used by ``act``)
        # is never rescaled.
        scaled_arms = self.arms / np.sqrt(self.d)
        prior = scaled_arms @ scaled_arms.T + 0.001 * np.eye(self.dim)

        hot_vectors = []
        for env in range(len(context_rewards)):
            one_traj_actions = context_actions[env]
            one_traj_rewards = context_rewards[env]

            if len(one_traj_rewards) < 1:
                indices = np.random.choice(np.arange(self.dim), size=self.batch_size)
                hot_vectors = np.zeros((self.batch_size, self.dim))
                hot_vectors[np.arange(self.batch_size), indices] = 1
                return hot_vectors

            counts, reward_sums = self._arm_statistics(one_traj_actions, one_traj_rewards)
            # NOTE(review): for a Gaussian prior with covariance S and unit noise,
            # the posterior mean would be (S^{-1} + D)^{-1} R. Here S enters as if
            # it were a prior precision — confirm which is intended.
            mean_value = np.linalg.solve(prior + np.diag(counts), reward_sums)

            # Deterministic selection of the best-scoring arm (first on ties).
            hot_vector = np.zeros(self.dim)
            hot_vector[np.argmax(mean_value)] = 1
            hot_vectors.append(hot_vector)

        return np.array(hot_vectors)


class BayesPredictorPolicy_test_S(OptPolicy):
    """Per-trajectory "Bayes predictor" baseline with a data-derived prior term.

    ``act_numpy_vec`` first sums each arm's rewards over all trajectories in
    ``train_dataset`` and divides by the number of trajectories, giving the
    vector ``a``. It then scores every trajectory's arms as ``(S + D)^{-1} R``
    and returns the best arm, where:
    * ``S = outer(a, a) + 0.001 I``;
    * ``D`` is the diagonal matrix of per-arm pull counts;
    * ``R`` holds the per-arm reward sums.

    ``act`` is the single-environment LinUCB rule.
    """

    def __init__(self, env, const=1.0, train_dataset = None, theta = None, batch_size=1):
        super().__init__(env)
        self.rand = True
        self.const = const
        self.arms = env.arms
        self.d = self.arms.shape[1]
        self.dim = env.dim
        self.init_cov = 1.0 * np.eye(self.d)
        self.batch_size = batch_size

        self.train_dataset = train_dataset
        self.theta = theta

    def act(self, x):
        """LinUCB action for a single environment from ``self.batch`` (``x`` unused)."""
        if len(self.batch['rollin_rs'][0]) < 1:
            i = np.random.choice(np.arange(self.dim))
            hot_vector = np.zeros(self.dim)
            hot_vector[i] = 1
            return hot_vector

        else:
            actions = self.batch['rollin_us'].cpu().detach().numpy()[0]
            rewards = self.batch['rollin_rs'].cpu().detach().numpy().flatten()

            actions_indices = np.argmax(actions, axis=1)
            actions_arms = self.arms[actions_indices]

            cov = self.init_cov + actions_arms.T @ actions_arms
            cov_inv = np.linalg.inv(cov)

            theta = cov_inv @ actions_arms.T @ rewards

            best_arm_index = None
            best_value = -np.inf
            for i, arm in enumerate(self.arms):
                value = theta @ arm + self.const * np.sqrt(arm @ cov_inv @ arm)
                if value > best_value:
                    best_value = value
                    best_arm_index = i

            hot_vector = np.zeros(self.dim)
            hot_vector[best_arm_index] = 1
            self.a = hot_vector
            return hot_vector

    def _arm_statistics(self, actions, rewards):
        """Return per-arm pull counts and reward sums for one trajectory.

        Args:
            actions: (H, dim) one-hot actions; the arm is the argmax of each row.
            rewards: H rewards aligned with ``actions`` (any shape with H elements).

        Returns:
            ``(counts, sums)``, two float arrays of length ``dim``.
        """
        chosen = np.argmax(actions, axis=1)
        counts = np.bincount(chosen, minlength=self.dim).astype(float)
        sums = np.bincount(
            chosen,
            weights=np.asarray(rewards, dtype=float).reshape(-1),
            minlength=self.dim,
        )
        return counts, sums

    def act_numpy_vec(self, x):
        """Pick the highest-scoring arm for every trajectory in ``train_dataset``.

        ``x`` and ``self.batch`` are not used.

        Returns:
            (num_trajectories, dim) one-hot array. It is empty when the dataset
            has no trajectories. If a trajectory with no rewards is encountered,
            the method instead returns a (batch_size, dim) array of uniformly
            random one-hot actions.
        """
        # TODO: parallelize this later
        context_actions = self.train_dataset['context_actions']
        context_rewards = self.train_dataset['context_rewards']
        num_envs = len(context_rewards)
        if num_envs == 0:
            return np.array([])

        # Each arm's reward summed over all trajectories, divided by the number
        # of trajectories.
        sum_reward_env = np.zeros(self.dim)
        for env in range(num_envs):
            _, traj_sums = self._arm_statistics(context_actions[env], context_rewards[env])
            sum_reward_env += traj_sums
        avg_reward_env = sum_reward_env / num_envs

        # Rank-one outer product of the averaged reward vector plus a small
        # ridge. ``avg @ avg.T`` on a 1-D array would be a scalar dot product.
        prior = np.outer(avg_reward_env, avg_reward_env) + 0.001 * np.eye(self.dim)

        hot_vectors = []
        for env in range(num_envs):
            one_traj_actions = context_actions[env]
            one_traj_rewards = context_rewards[env]

            if len(one_traj_rewards) < 1:
                indices = np.random.choice(np.arange(self.dim), size=self.batch_size)
                hot_vectors = np.zeros((self.batch_size, self.dim))
                hot_vectors[np.arange(self.batch_size), indices] = 1
                return hot_vectors

            counts, reward_sums = self._arm_statistics(one_traj_actions, one_traj_rewards)
            # NOTE(review): for a Gaussian prior with covariance S and unit noise,
            # the posterior mean would be (S^{-1} + D)^{-1} R. Here S enters as if
            # it were a prior precision — confirm which is intended.
            mean_value = np.linalg.solve(prior + np.diag(counts), reward_sums)

            # Deterministic selection of the best-scoring arm (first on ties).
            hot_vector = np.zeros(self.dim)
            hot_vector[np.argmax(mean_value)] = 1
            hot_vectors.append(hot_vector)

        return np.array(hot_vectors)



class BayesPredictorPolicy_train(OptPolicy):
    """Pooled "Bayes predictor" baseline over shuffled ``train_dataset`` batches.

    ``act_numpy_vec`` draws up to ``NUM_BATCHES`` shuffled batches of
    ``LOADER_BATCH_SIZE`` trajectories. For each batch it pools the per-arm
    pull counts ``N`` and reward sums ``R`` over all trajectories in the batch,
    then picks the arm maximizing ``R / (1 + N)``. This equals
    ``(I + D)^{-1} R``, with ``D`` the diagonal count matrix.

    ``act`` is the single-environment LinUCB rule.
    """

    NUM_BATCHES = 200  # batches scored per act_numpy_vec call
    LOADER_BATCH_SIZE = 100  # trajectories pooled per batch

    def __init__(self, env, const=1.0, train_dataset = None, theta = None, batch_size=1):
        super().__init__(env)
        self.rand = True
        self.const = const
        self.arms = env.arms
        self.d = self.arms.shape[1]
        self.dim = env.dim
        self.init_cov = 1.0 * np.eye(self.d)
        self.batch_size = batch_size

        self.train_dataset = train_dataset
        self.theta = theta

    def act(self, x):
        """LinUCB action for a single environment from ``self.batch`` (``x`` unused)."""
        if len(self.batch['rollin_rs'][0]) < 1:
            i = np.random.choice(np.arange(self.dim))
            hot_vector = np.zeros(self.dim)
            hot_vector[i] = 1
            return hot_vector

        else:
            actions = self.batch['rollin_us'].cpu().detach().numpy()[0]
            rewards = self.batch['rollin_rs'].cpu().detach().numpy().flatten()

            actions_indices = np.argmax(actions, axis=1)
            actions_arms = self.arms[actions_indices]

            cov = self.init_cov + actions_arms.T @ actions_arms
            cov_inv = np.linalg.inv(cov)

            theta = cov_inv @ actions_arms.T @ rewards

            best_arm_index = None
            best_value = -np.inf
            for i, arm in enumerate(self.arms):
                value = theta @ arm + self.const * np.sqrt(arm @ cov_inv @ arm)
                if value > best_value:
                    best_value = value
                    best_arm_index = i

            hot_vector = np.zeros(self.dim)
            hot_vector[best_arm_index] = 1
            self.a = hot_vector
            return hot_vector

    def _arm_statistics(self, actions, rewards):
        """Return per-arm pull counts and reward sums for one trajectory.

        Args:
            actions: (H, dim) one-hot actions; the arm is the argmax of each row.
            rewards: H rewards aligned with ``actions`` (any shape with H elements).

        Returns:
            ``(counts, sums)``, two float arrays of length ``dim``.
        """
        chosen = np.argmax(actions, axis=1)
        counts = np.bincount(chosen, minlength=self.dim).astype(float)
        sums = np.bincount(
            chosen,
            weights=np.asarray(rewards, dtype=float).reshape(-1),
            minlength=self.dim,
        )
        return counts, sums

    def act_numpy_vec(self, x):
        """Return one one-hot action per shuffled training batch.

        ``x`` and ``self.batch`` are not used.

        Returns:
            (num_batches, dim) one-hot array. ``num_batches`` is ``NUM_BATCHES``,
            or fewer if the loader runs out of batches.
        """
        # TODO: parallelize this later
        loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.LOADER_BATCH_SIZE,
            shuffle=True,
        )

        hot_vectors = []
        for batch in itertools.islice(loader, self.NUM_BATCHES):
            # Only host-side numpy arrays are needed, so no device transfer is made.
            batch_traj_actions = batch['context_actions'].cpu().detach().numpy()
            batch_traj_rewards = batch['context_rewards'].cpu().detach().numpy()

            counts = np.zeros(self.dim)
            reward_sums = np.zeros(self.dim)
            # Pair every trajectory's actions with that same trajectory's rewards.
            for one_traj_actions, one_traj_rewards in zip(batch_traj_actions, batch_traj_rewards):
                traj_counts, traj_sums = self._arm_statistics(one_traj_actions, one_traj_rewards)
                counts += traj_counts
                reward_sums += traj_sums

            # (I + D)^{-1} R with diagonal D reduces to an elementwise division.
            mean_value = reward_sums / (1.0 + counts)

            # Deterministic selection of the best-scoring arm (first on ties).
            hot_vector = np.zeros(self.dim)
            hot_vector[np.argmax(mean_value)] = 1
            hot_vectors.append(hot_vector)

        return np.array(hot_vectors)


class BayesPredictorPolicy_train_S(OptPolicy):
    """Per-arm reward predictor whose prior covariance is estimated from training data.

    ``act`` scores arms with a LinUCB-style rule on the roll-in history stored
    in ``self.batch``. ``act_numpy_vec`` estimates a prior covariance over
    arms from ``train_dataset`` and then, for each trajectory in
    ``uniform_dataset``, picks the arm maximising
    ``(prior_cov + diag(pulls))^-1 @ reward_sums``.
    """

    def __init__(self, env, const=1.0, train_dataset=None, uniform_dataset=None, theta=None, batch_size=1):
        super().__init__(env)
        self.rand = True
        self.const = const  # weight of the exploration bonus in ``act``
        self.arms = env.arms
        self.d = self.arms.shape[1]  # feature dimension of each arm
        self.dim = env.dim  # length of the one-hot action vectors
        self.init_cov = 1.0 * np.eye(self.d)  # ridge prior used by ``act``
        self.batch_size = batch_size

        self.train_dataset = train_dataset
        self.theta = theta
        self.uniform_dataset = uniform_dataset

    def act(self, x):
        """Return a one-hot action chosen by a LinUCB rule on the roll-in history.

        With no roll-in rewards yet, a uniformly random arm is returned.
        ``x`` is unused. The chosen action is also stored in ``self.a``.
        """
        if len(self.batch['rollin_rs'][0]) < 1:
            i = np.random.choice(np.arange(self.dim))
            hot_vector = np.zeros(self.dim)
            hot_vector[i] = 1
            return hot_vector

        actions = self.batch['rollin_us'].cpu().detach().numpy()[0]
        rewards = self.batch['rollin_rs'].cpu().detach().numpy().flatten()

        # Map each one-hot roll-in action to its arm feature vector.
        actions_arms = self.arms[np.argmax(actions, axis=1)]

        cov = self.init_cov + actions_arms.T @ actions_arms
        cov_inv = np.linalg.inv(cov)
        theta = cov_inv @ actions_arms.T @ rewards

        best_arm_index = None
        best_value = -np.inf
        for i, arm in enumerate(self.arms):
            value = theta @ arm + self.const * np.sqrt(arm @ cov_inv @ arm)
            if value > best_value:
                best_value = value
                best_arm_index = i

        hot_vector = np.zeros(self.dim)
        hot_vector[best_arm_index] = 1
        self.a = hot_vector
        return hot_vector

    def _estimate_prior_cov(self, num_batches=200):
        """Estimate a prior covariance over arms from ``self.train_dataset``.

        Sums the reward received by each arm over up to ``num_batches``
        shuffled batches of 10 trajectories, divides by the number of batches
        actually read, and returns ``outer(avg, avg) + 0.001 * I``.

        Raises:
            ValueError: if ``train_dataset`` yields no batches.
        """
        loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=10, shuffle=True)
        sum_reward = np.zeros(self.dim)
        n_batches = 0
        for batch in itertools.islice(loader, num_batches):
            actions = batch['context_actions'].cpu().detach().numpy()
            rewards = batch['context_rewards'].cpu().detach().numpy()
            arm_idx = np.argmax(actions, axis=-1).reshape(-1)
            # Assumes one reward per step, so rewards flatten to arm_idx.size.
            np.add.at(sum_reward, arm_idx, rewards.reshape(arm_idx.size))
            n_batches += 1
        if n_batches == 0:
            raise ValueError("train_dataset yielded no batches; cannot estimate the prior")
        avg_reward = sum_reward / n_batches
        # ``v @ v.T`` on a 1-D array is a scalar inner product (``.T`` is a
        # no-op), so the rank-one outer product must be written explicitly.
        return np.outer(avg_reward, avg_reward) + 0.001 * np.eye(self.dim)

    def act_numpy_vec(self, x):
        """Pick, for each trajectory in ``uniform_dataset``, the arm with the best mean estimate.

        ``x`` is unused. If any trajectory is empty, a ``(batch_size, dim)``
        array of uniformly random one-hot rows is returned instead.

        Returns:
            np.ndarray with one one-hot row per trajectory in ``uniform_dataset``.
        """
        # NOTE(review): the prior is re-estimated (re-reading and reshuffling
        # training batches) on every call; cache it if that is not intended.
        init_cov = self._estimate_prior_cov()

        all_actions = self.uniform_dataset['context_actions']
        hot_vectors = []
        for env, one_traj_rewards in enumerate(self.uniform_dataset['context_rewards']):
            one_traj_actions = all_actions[env]

            if len(one_traj_rewards) < 1:
                indices = np.random.choice(np.arange(self.dim), size=self.batch_size)
                hot_vectors = np.zeros((self.batch_size, self.dim))
                hot_vectors[np.arange(self.batch_size), indices] = 1
                return hot_vectors

            arm_idx = np.argmax(one_traj_actions, axis=1)
            pulls = np.bincount(arm_idx, minlength=self.dim)
            reward_sums = np.bincount(
                arm_idx,
                weights=np.asarray(one_traj_rewards, dtype=float).reshape(-1),
                minlength=self.dim,
            )
            # Mean estimate (prior_cov + diag(pulls))^-1 @ reward_sums.
            mean_value = np.linalg.solve(init_cov + np.diag(pulls), reward_sums)

            hot_vector = np.zeros(self.dim)
            hot_vector[np.argmax(mean_value)] = 1
            hot_vectors.append(hot_vector)

        return np.array(hot_vectors)




################## Bandit Transformer that explores ##################
class BanditTransformerController_explore(Controller):
    def __init__(self, model, train_dataset, sample=False, uniform_dataset=None, batch_size=1, tf_type="original"):
        """Wrap a bandit transformer ``model`` as a controller.

        Args:
            model: transformer exposing ``config['action_dim']``,
                ``config['state_dim']`` and ``horizon``; called on ``self.batch``.
            train_dataset: dataset that ``act_numpy_vec`` reads to build a
                prior covariance over arms.
            sample: if True, actions are sampled from ``softmax(scores)``
                instead of taken by argmax. ``act``/``act_numpy_vec`` may
                overwrite this depending on ``tf_type``.
            uniform_dataset: stored on the instance; not read by this class.
            batch_size: number of environments acted on per vectorised call.
            tf_type: ``"original"``, ``"new"`` or ``"new_opt_a"``; selects
                how the model outputs are interpreted.
        """
        # Initialise base-class state (``name``) so ``get_name()`` works.
        super().__init__()
        self.model = model
        self.du = model.config['action_dim']
        self.dx = model.config['state_dim']
        self.H = model.horizon
        self.sample = sample
        self.batch_size = batch_size
        # Fed to the model as batch['zeros']. NOTE(review): width
        # dx**2 + du + 1 is kept as-is; confirm it matches the model input.
        self.zeros = torch.zeros(batch_size, self.dx**2 + self.du + 1).float().to(device)
        self.tf_type = tf_type

        self.train_dataset = train_dataset
        self.uniform_dataset = uniform_dataset

    def set_env(self, env):
        """Deliberately ignore ``env``; unlike the base class, nothing is stored."""
        return None

    def set_batch_numpy_vec(self, batch):
        """Store ``batch`` after converting every value to a float tensor on ``device``.

        Keys are preserved; each value goes through ``torch.tensor(...).float()``.
        """
        tensor_batch = {
            key: torch.tensor(value).float().to(device)
            for key, value in batch.items()
        }
        self.set_batch(tensor_batch)

    def act(self, x):
        """Return a one-hot action of length ``du`` for the single query state ``x``.

        The action scores depend on ``tf_type``:

        * ``"original"``: the model output.
        * ``"new"``: with probability 0.5, the next-action output (and
          ``self.sample`` is set to True); otherwise the reward prediction
          (and ``self.sample`` is set to False).
        * ``"new_opt_a"``: the first of the model's three outputs.

        Any other ``tf_type`` raises ``NotImplementedError``. When
        ``self.sample`` is set, the arm is drawn from ``softmax(scores)``;
        otherwise the argmax is taken.
        """
        self.batch['zeros'] = self.zeros
        self.batch['query_states'] = torch.tensor(x)[None, :].float().to(device)

        if self.tf_type == "original":
            scores = self.model(self.batch)
        elif self.tf_type == "new":
            next_action_out, reward_out = self.model(self.batch)
            epsilon = 0.5
            use_next_action = np.random.random() < epsilon
            # The sampling mode follows whichever head was picked.
            scores = next_action_out if use_next_action else reward_out
            self.sample = True if use_next_action else False
        elif self.tf_type == "new_opt_a":
            scores, _, _ = self.model(self.batch)
        else:
            raise NotImplementedError

        scores = scores.cpu().detach().numpy()[0]

        if self.sample:
            chosen = np.random.choice(np.arange(self.du), p=scipy.special.softmax(scores))
        else:
            chosen = np.argmax(scores)

        one_hot = np.zeros(self.du)
        one_hot[chosen] = 1.0
        return one_hot

    def act_numpy_vec(self, x, ucb_const=0.1):
        """Choose one one-hot action per query state in ``x``.

        Args:
            x: query states; converted with ``np.array`` and given a leading
                batch axis when ``batch_size == 1``.
            ucb_const: weight of the per-arm exploration bonus used on the
                greedy (non-sampling) path.

        Returns:
            np.ndarray of shape ``(batch_size, du)`` with a single 1.0 per row.

        Raises:
            NotImplementedError: for an unknown ``tf_type``.
            ValueError: if ``train_dataset`` yields no batches.
        """
        self.batch['zeros'] = self.zeros

        states = torch.tensor(np.array(x))
        if self.batch_size == 1:
            states = states[None, :]
        self.batch['query_states'] = states.float().to(device)

        init_cov = self._estimate_prior_cov()

        if self.tf_type == "original":
            scores = self.model(self.batch)
            self.sample = True
        elif self.tf_type == "new":
            _, scores = self.model(self.batch)  # reward-prediction output
            self.sample = False
        elif self.tf_type == "new_opt_a":
            a_next, r_pred, a_opt = self.model(self.batch)
            # Per row: use the optimal-action output when its argmax agrees
            # with the reward prediction's argmax, else the next-action output.
            agree = (np.argmax(a_opt.cpu().detach().numpy(), axis=-1)
                     == np.argmax(r_pred.cpu().detach().numpy(), axis=-1))
            agree = torch.as_tensor(agree, device=a_next.device)[:, None]
            scores = torch.where(agree, a_opt, a_next)
            self.sample = False
        else:
            raise NotImplementedError

        # Keep scores 2-D even for batch_size == 1 so the per-row sampling
        # below always iterates over probability vectors, not scalars.
        scores = scores.cpu().detach().numpy().reshape(self.batch_size, self.du)

        if self.sample:
            temperature = 1.0
            probs = scipy.special.softmax(scores / temperature, axis=-1)
            action_indices = np.array([np.random.choice(np.arange(self.du), p=p) for p in probs])
        else:
            # UCB-style bonus sqrt(e_i^T Σ e_i) = sqrt(Σ_ii) for each one-hot
            # arm e_i, added to the scores *before* the argmax (indices must
            # stay integers). NOTE(review): this is an interpretation of the
            # original expression, which crashed; confirm the intended bonus.
            bonus = ucb_const * np.sqrt(np.diag(init_cov))
            action_indices = np.argmax(scores + bonus, axis=-1)

        actions = np.zeros((self.batch_size, self.du))
        actions[np.arange(self.batch_size), action_indices] = 1.0
        return actions

    def _estimate_prior_cov(self, num_batches=50):
        """Return ``outer(avg, avg) + 0.001 * I`` built from ``self.train_dataset``.

        ``avg`` is the reward summed per arm over up to ``num_batches``
        shuffled batches of 10 trajectories, divided by the number of batches
        actually read.

        Raises:
            ValueError: if ``train_dataset`` yields no batches.
        """
        loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=10, shuffle=True)
        sum_reward = np.zeros(self.du)
        n_batches = 0
        for batch in itertools.islice(loader, num_batches):
            actions = batch['context_actions'].cpu().detach().numpy()
            rewards = batch['context_rewards'].cpu().detach().numpy()
            arm_idx = np.argmax(actions, axis=-1).reshape(-1)
            # Assumes one reward per step, so rewards flatten to arm_idx.size.
            np.add.at(sum_reward, arm_idx, rewards.reshape(arm_idx.size))
            n_batches += 1
        if n_batches == 0:
            raise ValueError("train_dataset yielded no batches; cannot estimate the prior")
        avg_reward = sum_reward / n_batches
        # ``v @ v.T`` on a 1-D array is a scalar (``.T`` is a no-op), so the
        # rank-one outer product must be explicit.
        return np.outer(avg_reward, avg_reward) + 0.001 * np.eye(self.du)